#!/usr/bin/env python3
"""
Phase 1: Data Preparation for Extensions Paper
Loads full_panel.csv, downloads WDI trade balance, computes derived variables.
Output: extensions/data/processed/extensions_panel.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import requests
import time
import io

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR is two directories above this script; ROOT_DIR one level higher
# still, holding sibling project trees such as multilateral/.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR.joinpath("multilateral")

# Destination for the WDI cache and the final panel; created eagerly on import.
OUT_DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_DATA.mkdir(exist_ok=True, parents=True)


# ── Download WDI trade balance ─────────────────────────────────────────────

def download_trade_balance(force=False):
    """
    Download NE.RSB.GNFS.ZS (External balance on goods & services, % GDP)
    for 1970-2024 from the World Bank WDI API, caching the result as CSV.

    Parameters
    ----------
    force : bool, default False
        Ignore any existing cache file and download afresh.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year``, ``trade_balance_gdp`` with one row per
        (iso3, year). Observations with a null value or an empty
        ``countryiso3code`` are skipped. The ``country/all`` endpoint may
        also return regional/income aggregates; they are not filtered here.

    Raises
    ------
    requests.HTTPError
        If the API answers with a non-2xx status.
    RuntimeError
        If the API returns an error payload or no observations at all.
        Nothing is written to the cache in that case, so a failed download
        cannot poison later runs.
    """
    cache_path = OUT_DATA / "wdi_trade_balance_raw.csv"
    columns = ["iso3", "year", "trade_balance_gdp"]

    if cache_path.exists() and not force:
        try:
            cached = pd.read_csv(cache_path)
        except pd.errors.EmptyDataError:
            cached = None
        # An earlier failed run may have left an empty or header-less file;
        # only trust the cache if it has rows and the expected columns.
        if (cached is not None and not cached.empty
                and set(columns).issubset(cached.columns)):
            print(f"  Using cached trade balance data: {cache_path}")
            return cached
        print(f"  Ignoring unusable cache file, re-downloading: {cache_path}")

    print("  Downloading WDI trade balance (NE.RSB.GNFS.ZS)...")
    indicator = "NE.RSB.GNFS.ZS"
    base_url = "https://api.worldbank.org/v2/country/all/indicator"
    all_rows = []
    page = 1

    # One Session reuses the HTTPS connection across pages.
    with requests.Session() as session:
        while True:
            url = (f"{base_url}/{indicator}?format=json&per_page=1000"
                   f"&date=1970:2024&page={page}")
            resp = session.get(url, timeout=60)
            resp.raise_for_status()
            data = resp.json()

            # NOTE(review): the WDI API appears to report bad requests as a
            # one-element list whose entry carries a "message" key (with
            # HTTP 200); fail loudly rather than caching a partial result.
            if (isinstance(data, list) and data
                    and isinstance(data[0], dict) and "message" in data[0]):
                raise RuntimeError(
                    f"WDI API error for {indicator} (page {page}): "
                    f"{data[0]['message']}")

            if len(data) < 2 or data[1] is None:
                break

            all_rows.extend(
                (obs["countryiso3code"], int(obs["date"]), float(obs["value"]))
                for obs in data[1]
                if obs["value"] is not None and obs.get("countryiso3code")
            )

            total_pages = data[0]["pages"]
            print(f"    Page {page}/{total_pages} ({len(all_rows)} obs so far)")
            if page >= total_pages:
                break
            page += 1
            time.sleep(0.5)  # be polite to the public API between pages

    if not all_rows:
        raise RuntimeError(
            f"WDI API returned no observations for {indicator}; nothing cached")

    # Defensive: callers merge on (iso3, year), so a key repeated across pages
    # would otherwise duplicate their rows.
    df = (pd.DataFrame(all_rows, columns=columns)
          .drop_duplicates(subset=["iso3", "year"], keep="first")
          .reset_index(drop=True))
    df.to_csv(cache_path, index=False)
    print(f"  Downloaded {len(df):,} obs for {df['iso3'].nunique()} countries")
    return df


# ── Main ───────────────────────────────────────────────────────────────────

def _within_country_first_difference(fp, col):
    """Return the year-over-year change in *col* within each iso3 series.

    Expects *fp* sorted by (iso3, year). Where a country's previous row is
    not exactly the previous calendar year (a gap in the panel), the result
    is NaN instead of silently spanning several years.
    """
    by_country = fp.groupby("iso3")
    annual_step = by_country["year"].diff() == 1
    return by_country[col].diff().where(annual_step)


def _demographic_terciles(fp):
    """Classify rows into within-year old_dep terciles: 1 early, 2 mid, 3 late.

    Cut points are the 1/3 and 2/3 quantiles of old_dep among rows of the
    same year. Rows with missing old_dep or missing cut points stay NaN.
    """
    cuts = fp.groupby("year")["old_dep"].quantile([1 / 3, 2 / 3]).unstack()
    t1 = fp["year"].map(cuts[1 / 3])
    t2 = fp["year"].map(cuts[2 / 3])
    old = fp["old_dep"]

    tercile = pd.Series(np.nan, index=fp.index)
    tercile.loc[old <= t1] = 1                  # early
    tercile.loc[(old > t1) & (old <= t2)] = 2   # mid
    tercile.loc[old > t2] = 3                   # late
    return tercile


def _add_derived_variables(fp):
    """Return *fp* sorted by (iso3, year) with all derived regressors added."""
    # Income balance = CA - trade balance (primary + secondary income net).
    # It is defined as a residual, so CA = TB + IB holds by construction;
    # because ca_gdp and the WDI trade balance may come from different
    # sources, it presumably also absorbs any cross-source discrepancy.
    fp["income_balance_gdp"] = fp["ca_gdp"] - fp["trade_balance_gdp"]
    n_ib = fp["income_balance_gdp"].notna().sum()
    print(f"  Income balance (CA - TB) available for {n_ib:,} obs")

    # First difference of kaopen within country (consecutive years only)
    fp = fp.sort_values(["iso3", "year"])
    fp["d_kaopen"] = _within_country_first_difference(fp, "kaopen")

    # Liberalization indicator: 1 if kaopen rose, 0 otherwise, NaN if unknown
    fp["kaopen_liberalized"] = (fp["d_kaopen"] > 0).astype(float)
    fp.loc[fp["d_kaopen"].isna(), "kaopen_liberalized"] = np.nan

    # Fiscal × demographic interactions
    for z in ["Z_1", "Z_2", "Z_3"]:
        fp[f"fiscal_x_{z}"] = fp["fiscal_bal_gdp"] * fp[z]
    fp["fiscal_x_old_dep"] = fp["fiscal_bal_gdp"] * fp["old_dep"]

    # NFA × income balance interaction (for creditor/debtor analysis)
    fp["nfa_x_income_bal"] = fp["nfa_gdp_lag"] * fp["income_balance_gdp"]

    # Time trend centered on the median observation year
    fp["time_trend"] = fp["year"] - fp["year"].median()
    fp["oadr_x_trend"] = fp["old_dep"] * fp["time_trend"]
    fp["old_dep_sq"] = fp["old_dep"] ** 2

    fp["demo_tercile"] = _demographic_terciles(fp)

    # Savings / investment aliases for Feldstein-Horioka regressions
    # (gross_national_savings_gdp and gross_investment_gdp come from the panel)
    fp["savings_gdp"] = fp["gross_national_savings_gdp"]
    fp["investment_gdp"] = fp["gross_investment_gdp"]
    return fp


def _print_new_variable_summary(fp, variables):
    """Print N / mean / std for each of *variables* present in *fp*."""
    print("\n  New variables summary:")
    for v in variables:
        if v not in fp.columns:
            continue
        s = fp[v].dropna()
        print(f"    {v:<25s}: N={len(s):>5,}  mean={s.mean():>8.3f}  "
              f"std={s.std():>8.3f}")


def main():
    """Build extensions_panel.csv from full_panel.csv plus the WDI trade balance.

    Reads ``full_panel.csv`` under MULTILATERAL_DIR, keeps years <= 2024,
    left-merges the WDI trade balance on (iso3, year), adds derived variables,
    writes ``extensions_panel.csv`` to OUT_DATA and returns the panel.

    Raises
    ------
    pandas.errors.MergeError
        If the trade-balance data has duplicate (iso3, year) keys, which
        would otherwise silently duplicate panel rows.
    """
    print("=" * 70)
    print("PHASE 1: Data Preparation")
    print("=" * 70)

    # Load full panel
    fp_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    print(f"\n  Loading {fp_path}")
    fp = pd.read_csv(fp_path)
    fp = fp[fp["year"] <= 2024].copy()
    print(f"  Full panel: {fp['iso3'].nunique()} countries, {len(fp):,} obs, "
          f"{fp['year'].min()}-{fp['year'].max()}")

    # Download trade balance
    tb = download_trade_balance()

    # Merge trade balance; each panel row must match at most one WDI row.
    fp = fp.merge(tb, on=["iso3", "year"], how="left", validate="many_to_one")
    n_tb = fp["trade_balance_gdp"].notna().sum()
    pct_tb = 100 * n_tb / len(fp) if len(fp) else 0.0
    print(f"\n  Trade balance coverage: {n_tb:,} / {len(fp):,} obs "
          f"({pct_tb:.1f}%)")

    fp = _add_derived_variables(fp)

    _print_new_variable_summary(
        fp,
        ["trade_balance_gdp", "income_balance_gdp", "d_kaopen",
         "kaopen_liberalized", "fiscal_x_Z_1", "fiscal_x_old_dep",
         "demo_tercile", "old_dep_sq"],
    )

    # ── Save ───────────────────────────────────────────────────────────────
    out_path = OUT_DATA / "extensions_panel.csv"
    fp.to_csv(out_path, index=False)
    print(f"\n  Saved: {out_path}")
    print(f"  Shape: {fp.shape[0]:,} rows × {fp.shape[1]} columns")
    print(f"  Countries: {fp['iso3'].nunique()}")

    return fp


if __name__ == "__main__":
    main()
